module GradientAscent

using ProgressLogging

import Main.AlgorithmTools:TimePoint


export gradientAscent



"""
    gradientAscent(state, gradient; iterations=Int(1e3), step=t -> 1e-2,
                   with_progress=true, kwargs...)

Run `iterations` gradient-ascent updates on a deep copy of `state` and return the
recorded trajectory. The caller's `state` is not modified.

At every iteration `t ∈ 1:iterations` the gradient is evaluated once as
`gradient(state)`, and each component `i ∈ axes(state, 1)` is updated in place with
`state[i] .+= step(t) * gradient(state)[i]`.

# Arguments
- `state`: container indexed along its first axis; every `state[i]` must accept an
  in-place broadcast assignment. NOTE(review): presumably a vector of arrays —
  confirm against callers.
- `gradient`: callable; `gradient(state)` must be indexable with the same indices
  as `state`.

# Keywords
- `iterations`: number of update steps.
- `step`: callable mapping the iteration index `t` to the step multiplier. It is
  evaluated once per iteration.
- `with_progress`: when `true`, a `ProgressLogging.Progress` record is logged with
  `@info` after every iteration.
- `progress_id` (via `kwargs`): id passed to `ProgressLogging.Progress`; defaults to
  `ProgressLogging.ROOTID`.
- `progress_partition` (via `kwargs`): tuple `(from, to)`; the reported fraction
  moves linearly from `from` to `to` over the run. Defaults to `(0, 1)`.
- `progress_name` (via `kwargs`): label of the progress record; defaults to
  `"Gradient Ascent"`. The historical spelling `process_name` is still honoured
  when `progress_name` is not given.

Any other keyword in `kwargs` is ignored.

# Returns
A `Vector{TimePoint{typeof(state)}}` holding a deep-copied snapshot of the state at
time `0` followed by one snapshot per iteration.
"""
function gradientAscent(
    state, gradient;
    iterations=Int(1e3), 
    step=t -> 1e-2, 
    with_progress=true, 
    kwargs...
  )

  # Work on a private copy: the update below mutates `state` in place.
  state = deepcopy(state)

  trajectory = TimePoint{typeof(state)}[]

  # Every record needs its own deep copy, otherwise all entries would alias the
  # single `state` object that keeps being mutated.
  function pushTimePoint!(time, state)
    push!(trajectory, TimePoint{typeof(state)}(time, deepcopy(state)))
  end

  pushTimePoint!(0, state)

  progress_id = get(kwargs, :progress_id, ProgressLogging.ROOTID)

  # `progress_partition` is (from, to); keep the start and the width of the interval.
  current_progress, total_progress = get(kwargs, :progress_partition, (0, 1))
  total_progress -= current_progress

  # Prefer the key that matches the other `progress_*` options; fall back to the
  # legacy `process_name` spelling so existing callers keep their label.
  progress_name = get(
    kwargs, :progress_name,
    get(kwargs, :process_name, "Gradient Ascent")
  )

  for t ∈ 1:iterations

    gradientₜ = gradient(state)

    # The step size depends only on `t`, so evaluate it once per iteration.
    stepₜ = step(t)

    for i ∈ axes(state, 1)
      state[i] .+= stepₜ * gradientₜ[i]
    end

    pushTimePoint!(t, state)

    if with_progress
      @info ProgressLogging.Progress(
        progress_id,
        current_progress + total_progress * t / iterations;
        name=progress_name
      )
    end

  end

  return trajectory

end



end